Add wandb support #3053
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
| if self.config.managed_mldiagnostics: | ||
| ManagedMLDiagnostics(config) # Initialize the MLRun instance. | ||
|
|
||
| self.enable_wandb = self.config.enable_wandb and socket.gethostname().endswith("-0") # you should only init wandb on one host. |
Check jax process index instead.
|
Thanks @Zephyr271828, could you also share any pre-requisites to generate the WANDB_API_KEY? |
Thank you for your detailed review! I will modify the code accordingly. I don't think any prerequisite is needed to generate the API key? You may simply go to wandb.ai/settings#apikeys to generate an API key. Then you only need to set |
|
@Zephyr271828 Hi! |
|
@shralex @dipannita08 @gagika @kryvokhyzha Thank you for your support and detailed feedback! Below is a summary of the potential improvements you suggested:
Rank 0 detection
See here.
Performance
Due to the limited TPU resources I have, I tested the performance of training qwen3-0.6b from scratch w/ and w/o wandb on a v4-16 (2 TPU VMs) to simulate a basic multi-host training setup.
command
#!/bin/bash
set -euo pipefail
source get_tpu_bucket_name.sh
export TPU_PREFIX="$(get_tpu_name)"
export BUCKET_NAME="$(get_bucket_name)"
export NUM_HOSTS=$(get_num_hosts)
for arg in "$@"; do
case $arg in
--lr=*) LR="${arg#*=}" ;;
--batch_size=*) BATCH_SIZE="${arg#*=}" ;;
--global_batch_size=*) GLOBAL_BATCH_SIZE="${arg#*=}" ;;
--grad_clip=*) GRAD_CLIP="${arg#*=}" ;;
--min_lr_ratio=*) MIN_LR_RATIO="${arg#*=}" ;;
--warmup_ratio=*) WARMUP_RATIO="${arg#*=}" ;;
--max_to_keep=*) MAX_TO_KEEP="${arg#*=}" ;;
--data_files=*) DATA_FILES="${arg#*=}" ;;
--shuffle=*) SHUFFLE="${arg#*=}" ;;
--tag=*) TAG="${arg#*=}" ;;
*) echo "[WARN] Unknown arg $arg" ;;
esac
done
export MODEL_NAME="qwen3-0.6b"
export NUM_STEPS=50000
export SEQ_LEN=8192
export BATCH_SIZE=${BATCH_SIZE:-1}
export GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-64}
export GRAD_ACCUM=$((GLOBAL_BATCH_SIZE / BATCH_SIZE / NUM_HOSTS / 4))
export GRAD_CLIP=${GRAD_CLIP:-1.0}
export LR=${LR:-0.0003}
export MIN_LR_RATIO=${MIN_LR_RATIO:-0.1}
export WARMUP_RATIO=${WARMUP_RATIO:-0.05}
export ASYNC_CHECKPOINTING=false
export BASE_OUTPUT_DIRECTORY="gs://${BUCKET_NAME}/model_ckpts/maxtext"
export MAX_TO_KEEP=${MAX_TO_KEEP:-1}
export DATA_FILES="${DATA_FILES:-/home/zephyr/gcs-bucket/datasets/dclm/llama3_array_record_with_special_tokens_64/*.array_record}"
export SHUFFLE="${SHUFFLE:-True}"
export RUN_NAME="${MODEL_NAME}_L200_seqlen_${SEQ_LEN}_bs_${BATCH_SIZE}_grad_accum_${GRAD_ACCUM}_lr_${LR}_min_lr_ratio_${MIN_LR_RATIO}_warmup_ratio_${WARMUP_RATIO}"
if [ ! -z "${TAG:-}" ]; then
export RUN_NAME="${RUN_NAME}_${TAG}"
fi
export JAX_PLATFORMS=tpu
export SPARSE_MODEL_TRAINING=False
export PYTHONPATH=./src:${PYTHONPATH:-''}
python -u multihost_runner_orig.py \
--TPU_PREFIX=${TPU_PREFIX} \
--COMMAND="
export TPU_LOG_DIR=/home/zephyr/tpu_logs
export WANDB_API_KEY='<your-wandb-api-key>'
source ~/maxtext_env_py311/bin/activate
export PYTHONPATH=./src:\${PYTHONPATH:-''}
~/maxtext_env_py311/bin/python -u -m src.MaxText.train src/MaxText/configs/base.yml \
run_name=${RUN_NAME} \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
dataset_type=grain \
grain_train_files=${DATA_FILES} \
grain_file_type='arrayrecord' \
grain_worker_count=1 \
enable_data_shuffling=${SHUFFLE} \
tokenize_train_data=False \
tokenize_eval_data=False \
max_target_length=${SEQ_LEN} \
async_checkpointing=${ASYNC_CHECKPOINTING} \
model_name=${MODEL_NAME} \
steps=${NUM_STEPS} \
per_device_batch_size=${BATCH_SIZE} \
gradient_accumulation_steps=${GRAD_ACCUM} \
gradient_clipping_threshold=${GRAD_CLIP} \
learning_rate=${LR} \
warmup_steps_fraction=${WARMUP_RATIO} \
checkpoint_period=500 \
enable_wandb=True \
wandb_project_name=llm_pruning \
wandb_run_name=${TPU_PREFIX}_${RUN_NAME} \
packing=false \
sharding_tolerance=0.5 \
"
w/o wandb logs
I0220 17:10:13.291728 139933146753024 max_utils.py:695] Total memory size: 17.8 GB, Output size: 6.7 GB, Temp size: 11.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.
Per train step:
Total TFLOPs: 419.07
split as 55.92% learnable weight flops and 44.08% attention flops
I0220 17:10:13.300330 139933146753024 metric_logger.py:298] number parameters: 0.596 billion
I0220 17:10:13.362068 139847774729792 grain_pool.py:367] Grain pool will use 1 processes.
I0220 17:10:13.366581 139847774729792 grain_pool.py:440] Grain pool will start child processes.
I0220 17:10:13.369205 139847774729792 grain_pool.py:448] Grain pool started all child processes.
2026-02-20 17:10:16.334197: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-20 17:10:16.379936: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-20 17:10:17.931806: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-20 17:10:19.885891: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0220 17:10:30.075978 139933146753024 max_utils.py:654]
Memstats: After params initialized:
I0220 17:10:30.076180 139933146753024 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_0(process=0,(0,0,0,0))
I0220 17:10:30.076251 139933146753024 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_1(process=0,(1,0,0,0))
I0220 17:10:30.076312 139933146753024 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_2(process=0,(0,1,0,0))
I0220 17:10:30.076370 139933146753024 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_3(process=0,(1,1,0,0))
I0220 17:10:44.227162 139933146753024 metric_logger.py:194] completed step: 1, seconds: 16.716, TFLOP/s/device: 25.070, Tokens/s/device: 3920.633, total_weights: 524224, loss: 249.849
I0220 17:10:50.938067 139933146753024 metric_logger.py:194] completed step: 2, seconds: 0.478, TFLOP/s/device: 876.634, Tokens/s/device: 137092.270, total_weights: 524224, loss: 250.093
I0220 17:10:57.648830 139933146753024 metric_logger.py:194] completed step: 3, seconds: 13.688, TFLOP/s/device: 30.616, Tokens/s/device: 4787.852, total_weights: 524224, loss: 249.443
I0220 17:11:04.359475 139933146753024 metric_logger.py:194] completed step: 4, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.214, total_weights: 524224, loss: 247.791
I0220 17:11:11.070475 139933146753024 metric_logger.py:194] completed step: 5, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.462, total_weights: 524224, loss: 245.704
I0220 17:11:17.781303 139933146753024 metric_logger.py:194] completed step: 6, seconds: 6.713, TFLOP/s/device: 62.423, Tokens/s/device: 9761.967, total_weights: 524224, loss: 242.812
I0220 17:11:24.492303 139933146753024 metric_logger.py:194] completed step: 7, seconds: 6.709, TFLOP/s/device: 62.468, Tokens/s/device: 9769.032, total_weights: 524224, loss: 242.086
I0220 17:11:31.203217 139933146753024 metric_logger.py:194] completed step: 8, seconds: 6.711, TFLOP/s/device: 62.449, Tokens/s/device: 9766.048, total_weights: 524224, loss: 238.795
I0220 17:11:37.914210 139933146753024 metric_logger.py:194] completed step: 9, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.406, total_weights: 524224, loss: 235.700
I0220 17:11:44.624889 139933146753024 metric_logger.py:194] completed step: 10, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.869, total_weights: 524224, loss: 230.782
I0220 17:11:51.335738 139933146753024 metric_logger.py:194] completed step: 11, seconds: 6.711, TFLOP/s/device: 62.442, Tokens/s/device: 9765.028, total_weights: 524224, loss: 228.549
I0220 17:11:58.046651 139933146753024 metric_logger.py:194] completed step: 12, seconds: 6.710, TFLOP/s/device: 62.455, Tokens/s/device: 9766.943, total_weights: 524224, loss: 225.022
I0220 17:12:04.757413 139933146753024 metric_logger.py:194] completed step: 13, seconds: 6.711, TFLOP/s/device: 62.447, Tokens/s/device: 9765.717, total_weights: 524224, loss: 218.865
I0220 17:12:11.468343 139933146753024 metric_logger.py:194] completed step: 14, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.496, total_weights: 524224, loss: 214.666
I0220 17:12:18.179343 139933146753024 metric_logger.py:194] completed step: 15, seconds: 6.710, TFLOP/s/device: 62.450, Tokens/s/device: 9766.202, total_weights: 524224, loss: 208.936
I0220 17:12:24.890058 139933146753024 metric_logger.py:194] completed step: 16, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.237, total_weights: 524224, loss: 206.809
I0220 17:12:31.600906 139933146753024 metric_logger.py:194] completed step: 17, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.589, total_weights: 524224, loss: 201.072
I0220 17:12:38.311885 139933146753024 metric_logger.py:194] completed step: 18, seconds: 6.711, TFLOP/s/device: 62.449, Tokens/s/device: 9766.096, total_weights: 524224, loss: 195.263
I0220 17:12:45.023010 139933146753024 metric_logger.py:194] completed step: 19, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.346, total_weights: 524224, loss: 191.112
I0220 17:12:51.733779 139933146753024 metric_logger.py:194] completed step: 20, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.543, total_weights: 524224, loss: 185.623
I0220 17:12:58.444595 139933146753024 metric_logger.py:194] completed step: 21, seconds: 6.712, TFLOP/s/device: 62.439, Tokens/s/device: 9764.521, total_weights: 524224, loss: 180.439
I0220 17:13:05.155681 139933146753024 metric_logger.py:194] completed step: 22, seconds: 6.710, TFLOP/s/device: 62.450, Tokens/s/device: 9766.285, total_weights: 524224, loss: 178.119
I0220 17:13:11.866602 139933146753024 metric_logger.py:194] completed step: 23, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.330, total_weights: 524224, loss: 170.033
I0220 17:13:18.577209 139933146753024 metric_logger.py:194] completed step: 24, seconds: 6.711, TFLOP/s/device: 62.448, Tokens/s/device: 9765.981, total_weights: 524224, loss: 162.621
I0220 17:13:25.288255 139933146753024 metric_logger.py:194] completed step: 25, seconds: 6.713, TFLOP/s/device: 62.430, Tokens/s/device: 9763.065, total_weights: 524224, loss: 159.374
I0220 17:13:31.999057 139933146753024 metric_logger.py:194] completed step: 26, seconds: 6.709, TFLOP/s/device: 62.466, Tokens/s/device: 9768.813, total_weights: 524224, loss: 154.217
I0220 17:13:38.710025 139933146753024 metric_logger.py:194] completed step: 27, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.294, total_weights: 524224, loss: 148.877
I0220 17:13:45.420726 139933146753024 metric_logger.py:194] completed step: 28, seconds: 6.711, TFLOP/s/device: 62.447, Tokens/s/device: 9765.813, total_weights: 524224, loss: 144.224
I0220 17:13:52.131641 139933146753024 metric_logger.py:194] completed step: 29, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.349, total_weights: 524224, loss: 139.123
I0220 17:13:58.842322 139933146753024 metric_logger.py:194] completed step: 30, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.543, total_weights: 524224, loss: 135.702
w/ wandb logs
I0220 17:16:03.667288 140548037974016 max_utils.py:695] Total memory size: 17.8 GB, Output size: 6.7 GB, Temp size: 11.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.
wandb: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
wandb: Currently logged in as: yx3038 (yx3038-new-york-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.25.0
wandb: Run data is saved locally in /home/zephyr/2026-02-20-17-14-43/wandb/run-20260220_171603-k0646bs7
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run yufeng-qw-v4-16_2_qwen3-0.6b_L200_seqlen_8192_bs_1_grad_accum_8_lr_0.0003_min_lr_ratio_0.1_warmup_ratio_0.05
wandb: ⭐️ View project at https://wandb.ai/yx3038-new-york-university/llm_pruning
wandb: 🚀 View run at https://wandb.ai/yx3038-new-york-university/llm_pruning/runs/k0646bs7
Per train step:
Total TFLOPs: 419.07
split as 55.92% learnable weight flops and 44.08% attention flops
I0220 17:16:04.903574 140548037974016 metric_logger.py:298] number parameters: 0.596 billion
I0220 17:16:04.963233 140462667576896 grain_pool.py:367] Grain pool will use 1 processes.
I0220 17:16:04.969832 140462667576896 grain_pool.py:440] Grain pool will start child processes.
I0220 17:16:04.972459 140462667576896 grain_pool.py:448] Grain pool started all child processes.
2026-02-20 17:16:08.003286: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-20 17:16:08.051091: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-20 17:16:09.607792: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-20 17:16:11.587793: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0220 17:16:21.990377 140548037974016 max_utils.py:654]
Memstats: After params initialized:
I0220 17:16:21.990737 140548037974016 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_0(process=0,(0,0,0,0))
I0220 17:16:21.991039 140548037974016 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_1(process=0,(1,0,0,0))
I0220 17:16:21.991271 140548037974016 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_2(process=0,(0,1,0,0))
I0220 17:16:21.991589 140548037974016 max_utils.py:660] Using (GB) 6.7 / 30.75 (21.788618%) on TPU_3(process=0,(1,1,0,0))
I0220 17:16:34.658525 140548037974016 metric_logger.py:194] completed step: 1, seconds: 17.029, TFLOP/s/device: 24.609, Tokens/s/device: 3848.473, total_weights: 524224, loss: 249.849
I0220 17:16:41.369417 140548037974016 metric_logger.py:194] completed step: 2, seconds: 0.497, TFLOP/s/device: 843.702, Tokens/s/device: 131942.291, total_weights: 524224, loss: 250.093
I0220 17:16:48.080118 140548037974016 metric_logger.py:194] completed step: 3, seconds: 12.190, TFLOP/s/device: 34.378, Tokens/s/device: 5376.157, total_weights: 524224, loss: 249.443
I0220 17:16:54.790829 140548037974016 metric_logger.py:194] completed step: 4, seconds: 6.709, TFLOP/s/device: 62.462, Tokens/s/device: 9768.093, total_weights: 524224, loss: 247.791
I0220 17:17:01.501564 140548037974016 metric_logger.py:194] completed step: 5, seconds: 6.710, TFLOP/s/device: 62.458, Tokens/s/device: 9767.429, total_weights: 524224, loss: 245.704
I0220 17:17:08.212548 140548037974016 metric_logger.py:194] completed step: 6, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.154, total_weights: 524224, loss: 242.812
I0220 17:17:14.923486 140548037974016 metric_logger.py:194] completed step: 7, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.342, total_weights: 524224, loss: 242.086
I0220 17:17:21.636461 140548037974016 metric_logger.py:194] completed step: 8, seconds: 6.714, TFLOP/s/device: 62.421, Tokens/s/device: 9761.641, total_weights: 524224, loss: 238.795
I0220 17:17:28.345310 140548037974016 metric_logger.py:194] completed step: 9, seconds: 6.709, TFLOP/s/device: 62.468, Tokens/s/device: 9769.004, total_weights: 524224, loss: 235.700
I0220 17:17:35.056111 140548037974016 metric_logger.py:194] completed step: 10, seconds: 6.716, TFLOP/s/device: 62.401, Tokens/s/device: 9758.650, total_weights: 524224, loss: 230.782
I0220 17:17:41.766973 140548037974016 metric_logger.py:194] completed step: 11, seconds: 6.706, TFLOP/s/device: 62.489, Tokens/s/device: 9772.269, total_weights: 524224, loss: 228.549
I0220 17:17:48.477904 140548037974016 metric_logger.py:194] completed step: 12, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.521, total_weights: 524224, loss: 225.022
I0220 17:17:55.188795 140548037974016 metric_logger.py:194] completed step: 13, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.652, total_weights: 524224, loss: 218.865
I0220 17:18:01.899731 140548037974016 metric_logger.py:194] completed step: 14, seconds: 6.710, TFLOP/s/device: 62.453, Tokens/s/device: 9766.703, total_weights: 524224, loss: 214.666
I0220 17:18:08.610642 140548037974016 metric_logger.py:194] completed step: 15, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.824, total_weights: 524224, loss: 208.936
I0220 17:18:15.321358 140548037974016 metric_logger.py:194] completed step: 16, seconds: 6.712, TFLOP/s/device: 62.436, Tokens/s/device: 9764.086, total_weights: 524224, loss: 206.809
I0220 17:18:22.032296 140548037974016 metric_logger.py:194] completed step: 17, seconds: 6.712, TFLOP/s/device: 62.440, Tokens/s/device: 9764.665, total_weights: 524224, loss: 201.072
I0220 17:18:28.743230 140548037974016 metric_logger.py:194] completed step: 18, seconds: 6.708, TFLOP/s/device: 62.470, Tokens/s/device: 9769.367, total_weights: 524224, loss: 195.263
I0220 17:18:35.454278 140548037974016 metric_logger.py:194] completed step: 19, seconds: 6.712, TFLOP/s/device: 62.440, Tokens/s/device: 9764.703, total_weights: 524224, loss: 191.112
I0220 17:18:42.165038 140548037974016 metric_logger.py:194] completed step: 20, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.207, total_weights: 524224, loss: 185.623
I0220 17:18:48.876030 140548037974016 metric_logger.py:194] completed step: 21, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.468, total_weights: 524224, loss: 180.439
I0220 17:18:55.586965 140548037974016 metric_logger.py:194] completed step: 22, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.811, total_weights: 524224, loss: 178.119
I0220 17:19:02.298099 140548037974016 metric_logger.py:194] completed step: 23, seconds: 6.715, TFLOP/s/device: 62.405, Tokens/s/device: 9759.255, total_weights: 524224, loss: 170.033
I0220 17:19:09.008782 140548037974016 metric_logger.py:194] completed step: 24, seconds: 6.706, TFLOP/s/device: 62.494, Tokens/s/device: 9773.089, total_weights: 524224, loss: 162.621
I0220 17:19:15.719572 140548037974016 metric_logger.py:194] completed step: 25, seconds: 6.717, TFLOP/s/device: 62.390, Tokens/s/device: 9756.924, total_weights: 524224, loss: 159.374
I0220 17:19:22.430507 140548037974016 metric_logger.py:194] completed step: 26, seconds: 6.706, TFLOP/s/device: 62.491, Tokens/s/device: 9772.699, total_weights: 524224, loss: 154.217
I0220 17:19:29.141228 140548037974016 metric_logger.py:194] completed step: 27, seconds: 6.710, TFLOP/s/device: 62.457, Tokens/s/device: 9767.282, total_weights: 524224, loss: 148.877
I0220 17:19:35.852080 140548037974016 metric_logger.py:194] completed step: 28, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.672, total_weights: 524224, loss: 144.224
I0220 17:19:42.562924 140548037974016 metric_logger.py:194] completed step: 29, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.540, total_weights: 524224, loss: 139.123
I0220 17:19:49.273759 140548037974016 metric_logger.py:194] completed step: 30, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.534, total_weights: 524224, loss: 135.702
Logging more metrics
The metrics I log closely follow the implementation of tensorboard logging (i.e., I'm logging exactly the same metrics as tensorboard logging). Below is a list of the currently supported metrics:
Please let me know if any other metrics should be added to this list. Personally I think we may add RL metrics in subsequent PRs. |
Additionally, if you want to reproduce the experiments I've done, you may want to go to this commit. |
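To compare the two logs above without eyeballing them, the steady-state step time can be averaged per run. The following is a minimal sketch, not part of this PR: the helper name `mean_step_seconds` and the choice to skip steps before step 4 (the warm-up steps with irregular timings) are assumptions, and the embedded lines are shortened excerpts of the logs in this comment.

```python
import re
from statistics import mean

# Matches the "completed step: N, seconds: S" portion of a metric_logger line.
STEP_RE = re.compile(r"completed step: (\d+), seconds: ([\d.]+)")


def mean_step_seconds(log_text, first_step=4):
  """Average the per-step time, skipping warm-up steps before first_step."""
  times = [
      float(seconds)
      for step, seconds in STEP_RE.findall(log_text)
      if int(step) >= first_step
  ]
  return mean(times)


# Shortened excerpts (steps 3 to 6) of the two logs posted above.
without_wandb = """
completed step: 3, seconds: 13.688, TFLOP/s/device: 30.616
completed step: 4, seconds: 6.711, TFLOP/s/device: 62.443
completed step: 5, seconds: 6.710, TFLOP/s/device: 62.451
completed step: 6, seconds: 6.713, TFLOP/s/device: 62.423
"""
with_wandb = """
completed step: 3, seconds: 12.190, TFLOP/s/device: 34.378
completed step: 4, seconds: 6.709, TFLOP/s/device: 62.462
completed step: 5, seconds: 6.710, TFLOP/s/device: 62.458
completed step: 6, seconds: 6.711, TFLOP/s/device: 62.443
"""

baseline = mean_step_seconds(without_wandb)
candidate = mean_step_seconds(with_wandb)
print(f"w/o wandb: {baseline:.3f} s/step, w/ wandb: {candidate:.3f} s/step")
```

On these excerpts the two averages differ by roughly a millisecond per step, which is consistent with the full logs showing about 6.71 s per step in both runs.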
|
Thank you @Zephyr271828 ! Will take another look today! |
gagika
left a comment
overall looks good, could you please address the comments and rebase the code.
| if key not in valid_fields: | ||
| logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key)) | ||
| raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.") | ||
| raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.") |
MaxText uses " " for strings
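One way to keep double quotes everywhere is to build the joined list outside the f-string. Reusing double quotes inside a double-quoted f-string expression is only accepted from Python 3.12 onward, so this form also works on older interpreters. A minimal sketch; the helper name `invalid_key_message` is hypothetical and not part of MaxText:

```python
def invalid_key_message(key, valid_fields):
  """Builds the error text without nesting quotes inside the f-string."""
  allowed = ", ".join(map(repr, valid_fields))
  return f"{key!r} not in {allowed}."


print(invalid_key_message("enable_wandbb", ["enable_wandb", "wandb_run_name"]))
```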
| if self.enable_wandb: | ||
| self.write_metrics_to_wandb(metrics, step) |
could you check here that jax.process_index() == 0?
e.g.
if self.enable_wandb and self.config.enable_wandb and jax.process_index() == 0:
  self.write_metrics_to_wandb(metrics, step)
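The same guard can be factored into a small predicate so that initialization and the logging step share one definition of "the logging host". This is a sketch under assumptions: `should_log_to_wandb` is a hypothetical helper, and the process index is passed in as an argument so the snippet runs without JAX; in the trainer the value would come from `jax.process_index()`. Unlike the hostname suffix check, it does not depend on how the TPU VMs happen to be named.

```python
def should_log_to_wandb(enable_wandb, process_index):
  """Returns True only on the single process that should talk to wandb."""
  return bool(enable_wandb) and process_index == 0


# Simulate a 2-host setup such as the v4-16 used in the tests above.
for index in range(2):
  print(f"process {index}: log to wandb = {should_log_to_wandb(True, index)}")
```

In the logger this would be called as `should_log_to_wandb(self.config.enable_wandb, jax.process_index())`.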
dipannita08
left a comment
Thank you for the additional testing @Zephyr271828! This looks great, could you address last couple comments and rebase?
Sorry I did not see the previous comment. Will finish that by the end of this week:) |
Hi! @dipannita08 @gagika I just addressed the latest comments and rebased my commits. Thank you so much for your reviews! |
|
This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions. |
|
This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it. |
|
@Zephyr271828 it seems that this was automatically closed instead of merging :( could you please rebase ? |
|
@shralex Hi I just rebased the PR. Let me know if I need to change anything :) |
Description
This PR aims to implement #2434 and add wandb logging support to MaxText.
Implementation details
The implementation of wandb logging simply follows the style of other logging interfaces.
Initialization
Logging step
Usage
python -u -m src.MaxText.train src/MaxText/configs/base.yml \
  ...
  enable_wandb=True \
  wandb_project_name=xxx \
  wandb_run_name=yyy \
  ...
Limitations
Currently, this implementation does not support resuming from an existing wandb run. In order to resume, we need to first retrieve the run_id from somewhere, then pass it to wandb when initializing the run.
It makes sense to save the run_ids in some cache dir inside the maxtext repo, but I don't know whether that's consistent with the design philosophy of this project.
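A rough sketch of what resuming could look like, offered as an illustration rather than as part of this PR. It persists a run ID in a cache file and builds the keyword arguments for `wandb.init`, which accepts `id` and `resume`. The helper names, the cache file location, and the use of a random 8-character ID are all assumptions; in a multi-host job only process 0 would call it.

```python
import uuid
from pathlib import Path


def load_or_create_run_id(cache_file):
  """Returns the cached run ID, creating and storing one on first use."""
  cache_file = Path(cache_file)
  if cache_file.exists():
    return cache_file.read_text().strip()
  run_id = uuid.uuid4().hex[:8]  # Hypothetical ID scheme; any unique string works.
  cache_file.parent.mkdir(parents=True, exist_ok=True)
  cache_file.write_text(run_id)
  return run_id


def wandb_init_kwargs(project, run_name, cache_file):
  """Builds kwargs for wandb.init so a restarted job continues the same run."""
  return {
      "project": project,
      "name": run_name,
      "id": load_or_create_run_id(cache_file),
      "resume": "allow",  # Resume if the ID already exists, otherwise start fresh.
  }
```

The result would then be passed as `wandb.init(**wandb_init_kwargs("xxx", "yyy", cache_file))`. Keying the cache file on the run name would keep separate experiments from sharing an ID.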
Tests
Example training script:
Outputs:
Wandb output:

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.